"""Run the SMAC baseline configurations.

For every selected key in ``smac_params``, take the Brax search space's default
configuration, apply that key's parameter overrides, train it once per seed via
``Brax.train_batch``, and pickle the returned result to
``data/baselines/<env>/<key>/trajs.pickle``.
"""
import argparse
import os
import pickle

# Cap the fraction of accelerator memory XLA preallocates. JAX reads this only
# when its backend initializes, so it is set before the project imports below
# (which presumably pull in JAX/Brax).
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.05'

from search_spaces.brax_env import Brax  # noqa: E402 -- must follow the XLA env setup
from smac_baselines import smac_params  # noqa: E402 -- must follow the XLA env setup

parser = argparse.ArgumentParser()
parser.add_argument('-k', '--keys', nargs='+', default=None)
parser.add_argument('--seeds', nargs='+', type=int, default=[0, 1, 2, 3, 100, 200, 300])
parser.add_argument('-mp', '--max_parallel', type=int, default=1)
parser.add_argument('--anneal_lr', action='store_true', help='whether use a manual cosine learning rate annealing schedule')
parser.add_argument('--num_timesteps', type=int, default=int(200e6),
                    help='number of timesteps to train each seed for')
args = parser.parse_args()
print(vars(args))

seeds = args.seeds
if args.keys is not None:
    keys = [k for k in args.keys if k in smac_params]
    # Report requested keys that have no baseline instead of dropping them silently.
    unknown_keys = [k for k in args.keys if k not in smac_params]
    if unknown_keys:
        print(f'Ignoring unknown keys: {unknown_keys}')
else:
    keys = list(smac_params)
print(f'All matching keys: {keys}')

for key in keys:
    print(f'Running {key}')
    # The env name is the first '_'-separated token of the key
    # (presumably keys look like '<env>_<variant>' -- confirm against smac_params).
    env = key.split('_')[0]
    log_dir = f'data/baselines/{env}/{key}'
    os.makedirs(log_dir, exist_ok=True)

    # NAS is enabled unless the key contains 'no_nas'.
    ss = Brax(log_dir=log_dir, env_name=env, max_parallel=args.max_parallel,
              do_nas='no_nas' not in key)
    cs = ss.config_space

    # Start from the search space's default configuration (per the original
    # author: Brax's PPO hyperparameters from the Brax paper appendix), then
    # apply this baseline's overrides.
    config = cs.get_default_configuration()
    for name, value in smac_params[key].items():
        config[name] = value
    print(f'Current config={config}')
    ckpt_paths = [f'{log_dir}/seed_{seed}.pt' for seed in seeds]

    trajs = ss.train_batch([config] * len(seeds),
                           seeds=seeds, nums_timesteps=[args.num_timesteps] * len(seeds),
                           checkpoint_paths=ckpt_paths,
                           max_parallel=args.max_parallel,
                           anneal_lr=args.anneal_lr)

    # Close (and flush) the file before moving on to the next key.
    with open(f'{log_dir}/trajs.pickle', 'wb') as f:
        pickle.dump(trajs, f)
